from network import *
from generator import *
# Workspace root for the dataset.
# NOTE(review): not referenced anywhere in this script; other modules may
# import it, so it is kept — confirm before removing.
dataset_dir = '/home/aistudio/work/'

if __name__ == "__main__":
    # --- Model -------------------------------------------------------------
    # `base_model` comes from `network` (star import). The meaning of 34 is
    # not visible here — presumably the number of output classes; confirm.
    net = base_model(34)
    # NOTE(review): `lr` is the legacy alias of `learning_rate` in newer
    # Keras releases; kept as-is to stay compatible with the installed version.
    sgd = SGD(momentum=0.9, lr=0.001)
    net.compile(sgd, loss='categorical_crossentropy', metrics=['accuracy'])

    # --- Data --------------------------------------------------------------
    # Each generator loads its own compiled preprocessing library. Only the
    # training generator enables the `data_arg_on` flag.
    train_lib = "../C++/preprocess.so"
    val_lib = "../C++/preprocess2.so"
    train_gen = batch_generator(
        "train_discard.index",
        "train_discard.csv",
        data_arg_on=True,
        libname=train_lib,
    )
    print("train_gen init over")
    val_gen = batch_generator(
        "val_discard.index",
        "val_discard.csv",
        libname=val_lib,
    )
    print(len(train_gen))

    # --- Training ----------------------------------------------------------
    # ModelCheckpoint with save_best_only=True overwrites base_model.hdf5
    # only when the monitored metric improves. CSVLogger records per-epoch
    # metrics.
    checkpoint = ModelCheckpoint("base_model.hdf5", save_best_only=True)
    csv_log = CSVLogger("base_model_59.csv")
    # NOTE(review): `fit_generator` is deprecated since TF 2.1 in favour of
    # `Model.fit`, which accepts generators directly; kept for compatibility.
    history = net.fit_generator(
        generator=train_gen,
        steps_per_epoch=len(train_gen),
        epochs=1000,
        validation_data=val_gen,
        validation_steps=len(val_gen),
        callbacks=[checkpoint, csv_log],
    )